FIGURE 5.13
Structure of BinaryBERT-based BEBERT. The dashed lines denoted with A, B, and C represent combining the ensemble with different KD strategies.
5.8 BEBERT: Efficient and Robust Binary Ensemble BERT
On the basis of BinaryBERT, Tian et al. [222] proposed employing ensemble learning on binary BERT models, yielding Binary Ensemble BERT (BEBERT). Figure 5.13 shows the architecture of BEBERT based on BinaryBERT [6]. During training, BEBERT updates the sample weights of the training dataset in each iteration, focusing on the wrongly predicted samples. When knowledge distillation (KD) is used, forward propagation is performed with the full-precision teacher and the binary student; the gradient of the distillation loss is then computed to update the weights of the ternary student during backward propagation (BP), after which the parameters are binarized. The training process of BEBERT based on BiBERT is similar to that based on BinaryBERT, except that BiBERT is quantized from a full-precision student and distilled with the DMD method [195]. Note that the original two-stage KD [106] distills both the Transformer layers and the prediction layer, introducing extra forward and backward propagation steps during training. The authors therefore proposed distilling only the prediction layer, or removing the KD procedure altogether, to reduce the training cost.
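To make the per-round distillation step concrete, the following PyTorch sketch illustrates one KD update for a binary student with a full-precision teacher. It is a minimal illustration under simplifying assumptions rather than the authors' implementation: the student is modeled with latent real-valued weights binarized by a sign function with a straight-through estimator (a simplification of BinaryBERT's ternary weight splitting), only prediction-layer distillation is shown, and all class and function names are ours.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BinarizeSTE(torch.autograd.Function):
    """Sign binarization with a straight-through estimator (STE)."""
    @staticmethod
    def forward(ctx, latent_w):
        ctx.save_for_backward(latent_w)
        return torch.sign(latent_w)

    @staticmethod
    def backward(ctx, grad_out):
        (latent_w,) = ctx.saved_tensors
        # Pass gradients through to the latent weights, clipped outside [-1, 1].
        return grad_out * (latent_w.abs() <= 1).float()

class BinaryLinear(nn.Module):
    """Linear layer whose weights are binarized on every forward pass."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.latent_weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        w_bin = BinarizeSTE.apply(self.latent_weight)  # parameters binarized here
        return F.linear(x, w_bin, self.bias)

def kd_step(teacher, student, inputs, optimizer, T=1.0):
    """One prediction-layer KD step with a full-precision teacher and a binary student."""
    with torch.no_grad():
        t_logits = teacher(inputs)          # full-precision teacher forward
    s_logits = student(inputs)              # binary student forward
    # Soft-label distillation loss on the prediction layer only.
    loss = F.kl_div(
        F.log_softmax(s_logits / T, dim=-1),
        F.softmax(t_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    optimizer.zero_grad()
    loss.backward()    # gradients reach the latent weights via the STE
    optimizer.step()   # latent weights updated; re-binarized on the next forward
    return loss.item()
```

Because the sign function has zero gradient almost everywhere, the STE passes the gradient through to the latent real-valued weights; this is why the update acts on latent parameters even though the forward pass uses binary ones, matching the "update, then binarize" order described above.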
In detail, the authors used AdaBoost [67] to integrate multiple binary BERTs into BEBERT. AdaBoost is a popular ensemble learning method that combines the results of multiple weak learners to decrease the prediction bias. The AdaBoost-based BEBERT takes as input a training set $S$ of $m$ examples $(x_1, y_1), \dots, (x_m, y_m)$, where $y_j \in Y$ denotes the label of the $j$-th sample. The boosting algorithm then calls the binary BERT training procedure for $N$ rounds, generating one binary model per round. In the $i$-th round, AdaBoost supplies the training set with a distribution $D_i$ as the sample weights; the initial distribution $D_1$ is uniform over $S$, so $D_1(j) = 1/m$ for all $j$. The BERT training algorithm then computes a classifier $h_i$ (or $h_i^S$ when KD is employed) that minimizes the weighted error $e_i = \Pr_{j \sim D_i}[h_i(x_j) \neq y_j]$. Finally, the booster combines the weak hypotheses into a single final hypothesis $H(x) \leftarrow \sum_{i=1}^{N} \alpha_i h_i(x)$.
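The boosting loop itself can be sketched as follows. The helper train_binary_bert(X, y, sample_weights), the predict interface, and the SAMME-style multi-class weight formula are assumptions made for illustration; the text above only specifies the generic AdaBoost procedure with weighted error $e_i$ and combination weights $\alpha_i$.

```python
import math
import numpy as np

def adaboost_bebert(train_binary_bert, X, y, num_classes, N=3):
    """AdaBoost-style ensemble of N binary BERTs (SAMME-style multi-class weights).

    Assumptions: X is a sequence of inputs, y an np.ndarray of integer labels in
    [0, num_classes), and train_binary_bert(X, y, sample_weights) trains one
    binary BERT (optionally with KD) and returns a classifier h with
    h.predict(X) -> np.ndarray of predicted labels.
    """
    m = len(X)
    D = np.full(m, 1.0 / m)          # D_1: uniform distribution over the m samples
    ensemble = []                    # list of (alpha_i, h_i) pairs

    for i in range(N):
        h = train_binary_bert(X, y, sample_weights=D)   # i-th boosting round
        pred = h.predict(X)
        wrong = (pred != y).astype(float)
        e = float(np.sum(D * wrong))                    # weighted error e_i under D_i
        e = min(max(e, 1e-10), 1 - 1e-10)               # avoid division by zero
        # SAMME weight; reduces to the classic AdaBoost formula when num_classes = 2.
        alpha = math.log((1 - e) / e) + math.log(num_classes - 1)
        ensemble.append((alpha, h))
        # Increase the weights of wrongly predicted samples, then renormalize.
        D *= np.exp(alpha * wrong)
        D /= D.sum()

    def H(X_new):
        """Final hypothesis: weighted vote over the N binary BERTs."""
        votes = np.zeros((len(X_new), num_classes))
        for alpha, h in ensemble:
            pred = h.predict(X_new)
            votes[np.arange(len(X_new)), pred] += alpha
        return votes.argmax(axis=1)

    return H
```

The final hypothesis is evaluated here as the argmax of the weighted votes, which is the usual way of reading $H(x) \leftarrow \sum_{i=1}^{N} \alpha_i h_i(x)$ for classification.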